// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package types

import (
	"fmt"
	"reflect"
	"testing"
	"time"

	. "github.com/pingcap/check"
	"github.com/pingcap/tidb/parser/mysql"
	"github.com/pingcap/tidb/sessionctx/stmtctx"
)

// testDatumSuite hosts the gocheck tests for Datum defined in this file.
type testDatumSuite struct{}

// Register the suite with gocheck.
var _ = Suite(&testDatumSuite{})

// TestDatum round-trips a mix of Go values through Datum.SetValue/GetValue and,
// for each value, checks the collation, length and String accessors.
func (ts *testDatumSuite) TestDatum(c *C) {
	values := []interface{}{
		int64(1),
		uint64(1),
		1.1,
		"abc",
		[]byte("abc"),
		[]int{1},
	}
	for _, val := range values {
		var d Datum
		d.SetMinNotNull()
		d.SetValue(val)
		x := d.GetValue()
		c.Assert(x, DeepEquals, val)
		// Writing the current collation back must leave it unchanged; a bare
		// non-nil check alone would not show that SetCollation stored the value.
		coll := d.Collation()
		d.SetCollation(coll)
		c.Assert(d.Collation(), NotNil)
		c.Assert(d.Collation(), Equals, coll)
		c.Assert(d.Length(), Equals, int(d.length))
		c.Assert(fmt.Sprint(d), Equals, d.String())
	}
}

// testDatumToBool wraps in as a Datum, converts it with ToBool under a statement
// context that ignores truncation, and asserts the conversion succeeds with the
// value res.
func testDatumToBool(c *C, in interface{}, res int) {
	sc := new(stmtctx.StatementContext)
	sc.IgnoreTruncate = true
	d := NewDatum(in)
	got, err := d.ToBool(sc)
	c.Assert(err, IsNil)
	c.Assert(got, Equals, int64(res))
}

// TestToBool checks Datum.ToBool for zero integers, fractional floats and
// string/byte inputs. Per these expectations, 0.5 counts as true while smaller
// fractions count as false.
func (ts *testDatumSuite) TestToBool(c *C) {
	cases := []struct {
		in   interface{}
		want int
	}{
		{int(0), 0},
		{int64(0), 0},
		{uint64(0), 0},
		{float32(0.1), 0},
		{float64(0.1), 0},
		{float64(0.5), 1},
		{float64(0.499), 0},
		{"", 0},
		{"0.1", 0},
		{[]byte{}, 0},
		{[]byte("0.1"), 0},
	}
	for _, tc := range cases {
		testDatumToBool(c, tc.in, tc.want)
	}
}

// TestEqualDatums checks EqualDatums on equal rows, rows differing in a single
// value, and length/NULL corner cases.
func (ts *testDatumSuite) TestEqualDatums(c *C) {
	type pair struct {
		a, b []interface{}
		same bool
	}
	cases := []pair{
		// Equal rows.
		{[]interface{}{1}, []interface{}{1}, true},
		{[]interface{}{1, "aa"}, []interface{}{1, "aa"}, true},
		{[]interface{}{1, "aa", 1}, []interface{}{1, "aa", 1}, true},

		// Rows that differ in one value.
		{[]interface{}{1}, []interface{}{2}, false},
		{[]interface{}{1, "a"}, []interface{}{1, "aaaaaa"}, false},
		{[]interface{}{1, "aa", 3}, []interface{}{1, "aa", 2}, false},

		// Empty rows, NULLs and length mismatches.
		{[]interface{}{}, []interface{}{}, true},
		{[]interface{}{nil}, []interface{}{nil}, true},
		{[]interface{}{}, []interface{}{1}, false},
		{[]interface{}{1}, []interface{}{1, 1}, false},
		{[]interface{}{nil}, []interface{}{1}, false},
	}
	for i := range cases {
		testEqualDatums(c, cases[i].a, cases[i].b, cases[i].same)
	}
}

// testEqualDatums turns a and b into datum rows and asserts that EqualDatums,
// run with truncation ignored, succeeds and reports same.
func testEqualDatums(c *C, a []interface{}, b []interface{}, same bool) {
	sc := new(stmtctx.StatementContext)
	sc.IgnoreTruncate = true
	lhs := MakeDatums(a...)
	rhs := MakeDatums(b...)
	equal, err := EqualDatums(sc, lhs, rhs)
	c.Assert(err, IsNil)
	c.Assert(equal, Equals, same, Commentf("a: %v, b: %v", a, b))
}

// testDatumToInt64 wraps val as a Datum and asserts that ToInt64, run with
// truncation ignored, succeeds and returns expect.
func testDatumToInt64(c *C, val interface{}, expect int64) {
	datum := NewDatum(val)
	ctx := new(stmtctx.StatementContext)
	ctx.IgnoreTruncate = true
	got, err := datum.ToInt64(ctx)
	c.Assert(err, IsNil)
	c.Assert(got, Equals, expect)
}

// TestToInt64 checks Datum.ToInt64 on string, integer and float inputs; the
// float cases expect 3.1 to become 3.
func (ts *testTypeConvertSuite) TestToInt64(c *C) {
	cases := []struct {
		in   interface{}
		want int64
	}{
		{"0", 0},
		{int(0), 0},
		{int64(0), 0},
		{uint64(0), 0},
		{float32(3.1), 3},
		{float64(3.1), 3},
	}
	for _, tc := range cases {
		testDatumToInt64(c, tc.in, tc.want)
	}
}

// TestToFloat32 checks conversion to a FLOAT field type from float64 and string
// sources, then conversion of a float32 datum to a DOUBLE field type.
func (ts *testTypeConvertSuite) TestToFloat32(c *C) {
	floatCol := NewFieldType(mysql.TypeFloat)
	src := NewFloat64Datum(281.37)
	sc := new(stmtctx.StatementContext)
	sc.IgnoreTruncate = true

	// A float64 source must narrow to a KindFloat32 datum equal to float32(281.37).
	out, err := src.ConvertTo(sc, floatCol)
	c.Assert(err, IsNil)
	c.Assert(out.Kind(), Equals, KindFloat32)
	c.Assert(out.GetFloat32(), Equals, float32(281.37))

	// The same literal held as a string must narrow to the same float32.
	src.SetString("281.37")
	out, err = src.ConvertTo(sc, floatCol)
	c.Assert(err, IsNil)
	c.Assert(out.Kind(), Equals, KindFloat32)
	c.Assert(out.GetFloat32(), Equals, float32(281.37))

	// A float32 widened to DOUBLE keeps float32 precision, so it differs from
	// the float64 literal 281.37 but matches the source datum's float64 view.
	doubleCol := NewFieldType(mysql.TypeDouble)
	src = NewFloat32Datum(281.37)
	out, err = src.ConvertTo(sc, doubleCol)
	c.Assert(err, IsNil)
	c.Assert(out.Kind(), Equals, KindFloat64)
	c.Assert(out.GetFloat64(), Not(Equals), 281.37)
	c.Assert(out.GetFloat64(), Equals, src.GetFloat64())
}

// TestToFloat64 checks Datum.ToFloat64 across numeric, string and byte-slice
// datums. It also checks that an unsupported Go type (a plain byte) yields a
// "cannot convert" error and a zero result.
func (ts *testTypeConvertSuite) TestToFloat64(c *C) {
	cases := []struct {
		d      Datum
		errMsg string
		result float64
	}{
		{NewDatum(float32(3.00)), "", 3.00},
		{NewDatum(float64(12345.678)), "", 12345.678},
		{NewDatum("12345.678"), "", 12345.678},
		{NewDatum([]byte("12345.678")), "", 12345.678},
		{NewDatum(int64(12345)), "", 12345},
		{NewDatum(uint64(123456)), "", 123456},
		{NewDatum(byte(123)), "cannot convert .*", 0},
	}

	sc := new(stmtctx.StatementContext)
	sc.IgnoreTruncate = true
	for _, tc := range cases {
		got, err := tc.d.ToFloat64(sc)
		if tc.errMsg != "" {
			c.Assert(err, ErrorMatches, tc.errMsg)
		} else {
			c.Assert(err, IsNil)
		}
		c.Assert(got, Equals, tc.result)
	}
}

// TestIsNull checks that only a nil input produces a NULL datum; zero values
// and the empty string do not.
func (ts *testDatumSuite) TestIsNull(c *C) {
	cases := []struct {
		in       interface{}
		wantNull bool
	}{
		{nil, true},
		{0, false},
		{1, false},
		{1.1, false},
		{"string", false},
		{"", false},
	}
	for _, tc := range cases {
		testIsNull(c, tc.in, tc.wantNull)
	}
}

// testIsNull asserts that wrapping data in a Datum gives IsNull() == isnull.
func testIsNull(c *C, data interface{}, isnull bool) {
	datum := NewDatum(data)
	got := datum.IsNull()
	c.Assert(got, Equals, isnull, Commentf("data: %v, isnull: %v", data, isnull))
}

// TestToBytes checks the byte-slice rendering Datum.ToBytes produces for int,
// float and string datums. ToBytes takes no statement context, so none is set up.
func (ts *testDatumSuite) TestToBytes(c *C) {
	tests := []struct {
		a   Datum
		out []byte
	}{
		{NewIntDatum(1), []byte("1")},
		{NewFloat64Datum(1.23), []byte("1.23")},
		{NewStringDatum("abc"), []byte("abc")},
	}
	for _, tt := range tests {
		bin, err := tt.a.ToBytes()
		c.Assert(err, IsNil)
		c.Assert(bin, BytesEquals, tt.out)
	}
}

// TestComputePlusAndMinus checks ComputePlus on int/uint/float operand pairs
// and on operand pairs it must reject. For rejected pairs, the returned datum
// is compared against the zero Datum.
//
// NOTE(review): the minus column is populated but never asserted here; wire it
// up if a subtraction counterpart of ComputePlus is meant to be covered.
func (ts *testDatumSuite) TestComputePlusAndMinus(c *C) {
	sc := &stmtctx.StatementContext{TimeZone: time.UTC}
	tests := []struct {
		a      Datum
		b      Datum
		plus   Datum
		minus  Datum
		hasErr bool
	}{
		{NewIntDatum(72), NewIntDatum(28), NewIntDatum(100), NewIntDatum(44), false},
		{NewIntDatum(72), NewUintDatum(28), NewIntDatum(100), NewIntDatum(44), false},
		{NewUintDatum(72), NewUintDatum(28), NewUintDatum(100), NewUintDatum(44), false},
		{NewUintDatum(72), NewIntDatum(28), NewUintDatum(100), NewUintDatum(44), false},
		{NewFloat64Datum(72.0), NewFloat64Datum(28.0), NewFloat64Datum(100.0), NewFloat64Datum(44.0), false},
		{NewIntDatum(72), NewFloat64Datum(42), Datum{}, Datum{}, true},
		{NewStringDatum("abcd"), NewIntDatum(42), Datum{}, Datum{}, true},
	}

	for ith, tt := range tests {
		got, err := ComputePlus(tt.a, tt.b)
		// Report which case and which error on mismatch.
		c.Assert(err != nil, Equals, tt.hasErr, Commentf("%dth err: %v", ith, err))
		v, err := got.CompareDatum(sc, &tt.plus)
		c.Assert(err, IsNil)
		c.Assert(v, Equals, 0, Commentf("%dth got:%#v, expect:%#v", ith, got, tt.plus))
	}
}

// TestCloneDatum verifies that CloneDatum returns a datum that compares equal
// to its source. For datums that carry bytes, it also verifies that the clone's
// first byte lives at a different address than the source's.
func (ts *testDatumSuite) TestCloneDatum(c *C) {
	var raw Datum
	raw.b = []byte("raw")
	raw.k = KindRaw
	tests := []Datum{
		NewIntDatum(72),
		NewUintDatum(72),
		NewStringDatum("abcd"),
		NewBytesDatum([]byte("abcd")),
		raw,
	}

	sc := new(stmtctx.StatementContext)
	sc.IgnoreTruncate = true
	for _, tt := range tests {
		tt1 := CloneDatum(tt)
		res, err := tt.CompareDatum(sc, &tt1)
		c.Assert(err, IsNil)
		c.Assert(res, Equals, 0)
		// Guard on length, not nil-ness: an empty but non-nil slice has no
		// element 0 to take the address of.
		if len(tt.b) > 0 {
			c.Assert(&tt.b[0], Not(Equals), &tt1.b[0])
		}
	}
}

// prepareCompareDatums returns two separately allocated datum slices with
// element-wise equal contents (an int, a float and a string), for use by the
// comparison benchmarks.
func prepareCompareDatums() ([]Datum, []Datum) {
	build := func() []Datum {
		ds := make([]Datum, 0, 5)
		return append(ds, NewIntDatum(1), NewFloat64Datum(1.23), NewStringDatum("abcde"))
	}
	left := build()
	right := build()
	return left, right
}

// BenchmarkCompareDatum measures Datum.CompareDatum over the pairs returned by
// prepareCompareDatums. Elements are addressed by index so that copying each
// Datum out of the slice is not part of the timed work. A comparison error
// aborts the benchmark rather than being silently ignored.
func BenchmarkCompareDatum(b *testing.B) {
	vals, vals1 := prepareCompareDatums()
	sc := new(stmtctx.StatementContext)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for j := range vals {
			if _, err := vals[j].CompareDatum(sc, &vals1[j]); err != nil {
				b.Fatal(err)
			}
		}
	}
}

// BenchmarkCompareDatumByReflect is the reflection-based baseline for
// BenchmarkCompareDatum: it deep-compares the two prepared datum slices.
func BenchmarkCompareDatumByReflect(b *testing.B) {
	lhs, rhs := prepareCompareDatums()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = reflect.DeepEqual(lhs, rhs)
	}
}
